// SPDX-FileCopyrightText: 2025 kvcache-ai
// SPDX-FileCopyrightText: 2025 Qingcheng.AI
//
// SPDX-License-Identifier: Apache-2.0

/**
 * This file has adaption of open-source code from the following sources:
 * - https://github.com/kvcache-ai/ktransformers, licensed under Apache 2.0.
 */

#include "moe.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

/**
 * Builds an MoE layer from `config`.
 *
 * The expert weight pointers in `config` are stored as-is; the destructor
 * does not free them. Two sets of scratch buffers are registered with
 * shared_mem_buffer under `this` (released again in the destructor):
 *   - s_*: used by forward_one (one token, up to routed_expert_num experts);
 *   - m_*: used by forward_many (up to group_max_len tokens, each routed to
 *     up to routed_expert_num experts).
 * Quantized activation buffers are sized for the vec_dot_type that ggml
 * pairs with the corresponding weight type.
 *
 * @param config layer dimensions, weight pointers/types and batching limits.
 */
MOE::MOE(MOEConfig config) {
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    // Bytes needed for `n` activations quantized to the vec_dot_type that
    // matches `weight_type`. Everything is computed in size_t so products of
    // config dimensions cannot overflow int.
    auto vec_dot_bytes = [](auto weight_type, size_t n) -> uint64_t {
        const auto vec_dot_type =
            ggml_internal_get_type_traits(weight_type).vec_dot_type;
        return n * ggml_type_size(vec_dot_type) / ggml_blck_size(vec_dot_type);
    };
    const size_t hidden = config_.hidden_size;
    const size_t intermediate = config_.intermediate_size;
    // Upper bound on (token, expert) pairs handled by one forward_many call.
    const size_t group_slots =
        (size_t)config_.routed_expert_num * config_.group_max_len;

    // ---- forward_one scratch -------------------------------------------
    std::vector<std::pair<void **, uint64_t>> s_mem_requests;
    s_mem_requests.push_back(
        {(void **)&s_input_fp32_, sizeof(float) * hidden});
    s_mem_requests.push_back(
        {(void **)&s_gate_input_, vec_dot_bytes(config_.gate_type, hidden)});
    s_mem_requests.push_back(
        {(void **)&s_up_input_, vec_dot_bytes(config_.up_type, hidden)});
    // Resize before taking element addresses: the requests hold pointers
    // into these vectors, so they must not reallocate afterwards.
    s_gate_output_.resize(config_.routed_expert_num);
    s_up_output_.resize(config_.routed_expert_num);
    s_intermediate_fp32_.resize(config_.routed_expert_num);
    s_down_input_.resize(config_.routed_expert_num);
    s_down_output_.resize(config_.routed_expert_num);
    for (int i = 0; i < config_.routed_expert_num; i++) {
        s_mem_requests.push_back(
            {(void **)&s_gate_output_[i], sizeof(float) * intermediate});
        s_mem_requests.push_back(
            {(void **)&s_up_output_[i], sizeof(float) * intermediate});
        s_mem_requests.push_back(
            {(void **)&s_intermediate_fp32_[i], sizeof(float) * intermediate});
        s_mem_requests.push_back(
            {(void **)&s_down_input_[i],
             vec_dot_bytes(config_.down_type, intermediate)});
        s_mem_requests.push_back(
            {(void **)&s_down_output_[i], sizeof(float) * hidden});
    }
    s_mem_requests.push_back(
        {(void **)&s_output_fp32_, sizeof(float) * hidden});
    shared_mem_buffer.alloc(this, s_mem_requests);

    // ---- forward_many scratch ------------------------------------------
    std::vector<std::pair<void **, uint64_t>> m_mem_requests;
    m_input_fp32_.resize(config_.group_max_len);
    m_gate_input_.resize(config_.group_max_len);
    m_up_input_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_mem_requests.push_back(
            {(void **)&m_input_fp32_[i], sizeof(float) * hidden});
        m_mem_requests.push_back({(void **)&m_gate_input_[i],
                                  vec_dot_bytes(config_.gate_type, hidden)});
        m_mem_requests.push_back({(void **)&m_up_input_[i],
                                  vec_dot_bytes(config_.up_type, hidden)});
    }
    // Flat buffers; forward_many carves per-expert regions out of them.
    m_mem_requests.push_back(
        {(void **)&m_local_gate_input_,
         vec_dot_bytes(config_.gate_type, group_slots * hidden)});
    m_mem_requests.push_back(
        {(void **)&m_local_up_input_,
         vec_dot_bytes(config_.up_type, group_slots * hidden)});
    m_mem_requests.push_back({(void **)&m_local_gate_output_,
                              sizeof(float) * group_slots * intermediate});
    m_mem_requests.push_back({(void **)&m_local_up_output_,
                              sizeof(float) * group_slots * intermediate});
    m_mem_requests.push_back({(void **)&m_local_intermediate_fp32_,
                              sizeof(float) * group_slots * intermediate});
    m_mem_requests.push_back(
        {(void **)&m_local_down_input_,
         vec_dot_bytes(config_.down_type, group_slots * intermediate)});
    m_mem_requests.push_back({(void **)&m_local_down_output_,
                              sizeof(float) * group_slots * hidden});
    m_output_fp32_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_mem_requests.push_back(
            {(void **)&m_output_fp32_[i], sizeof(float) * hidden});
    }
    shared_mem_buffer.alloc(this, m_mem_requests);

    // Host-side bookkeeping for forward_many's per-expert bucketing.
    m_local_pos_.resize(config_.group_max_len);
    for (auto &token_pos : m_local_pos_) {
        token_pos.resize(config_.routed_expert_num);
    }
    m_local_num_.resize(config_.expert_num);
    m_local_gate_input_ptr_.resize(config_.expert_num);
    m_local_up_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);
    m_local_down_input_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);
}

/// Releases every scratch buffer the constructor registered with
/// shared_mem_buffer under `this`. The expert weight pointers taken from the
/// config are not freed here.
MOE::~MOE() {
    shared_mem_buffer.dealloc(this);
}

/**
 * Runs every expert once, through forward_one, on an all-zero token with a
 * zero mixing weight. This reads each expert's gate/up/down weights once
 * (presumably so they are resident before real traffic arrives).
 *
 * @param cpuinfer thread pool forwarded to forward_one.
 */
void MOE::warm_up(CPUInfer *cpuinfer) {
    // One hidden_size row encoded in config_.hidden_type.
    const size_t row_bytes = config_.hidden_size *
                             ggml_type_size(config_.hidden_type) /
                             ggml_blck_size(config_.hidden_type);
    std::vector<uint8_t> output(row_bytes);
    std::vector<uint8_t> input(row_bytes);

    // std::vector value-initialises its floats, so this row is already zero.
    std::vector<float> zero_row(config_.hidden_size);
    from_float(zero_row.data(), input.data(), config_.hidden_size,
               config_.hidden_type);

    const float zero_weight = 0;
    for (int expert = 0; expert < config_.expert_num; ++expert) {
        const uint64_t expert_id = expert;
        forward_one(1, &expert_id, &zero_weight, input.data(), output.data(),
                    cpuinfer);
    }
}

/// SiLU (swish) activation: x * sigmoid(x), evaluated as x / (1 + e^-x).
static float act_fn(float x) {
    const float denom = 1.0f + expf(-x);
    return x / denom;
}

void MOE::forward_one(int k, const uint64_t *expert_ids, const float *weights,
                      const void *input, void *output, CPUInfer *cpuinfer) {
    const void *gate_input_ptr;
    const void *up_input_ptr;
    if (config_.hidden_type ==
            ggml_internal_get_type_traits(config_.gate_type).vec_dot_type &&
        config_.hidden_type ==
            ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
        gate_input_ptr = up_input_ptr = input;
    } else {
        to_float(input, s_input_fp32_, config_.hidden_size,
                 config_.hidden_type);
        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type ==
            ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
            from_float(
                s_input_fp32_, s_gate_input_, config_.hidden_size,
                ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
            gate_input_ptr = up_input_ptr = s_gate_input_;
        } else {
            if (config_.hidden_type !=
                ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
                from_float(s_input_fp32_, s_gate_input_, config_.hidden_size,
                           ggml_internal_get_type_traits(config_.gate_type)
                               .vec_dot_type);
                gate_input_ptr = s_gate_input_;
            } else {
                gate_input_ptr = input;
            }
            if (config_.hidden_type !=
                ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
                from_float(s_input_fp32_, s_up_input_, config_.hidden_size,
                           ggml_internal_get_type_traits(config_.up_type)
                               .vec_dot_type);
                up_input_ptr = s_up_input_;
            } else {
                up_input_ptr = input;
            }
        }
    }
    int nth = config_.intermediate_size / config_.stride;
    cpuinfer->parallel_for(nth * k, [&](int task_id) {
        int expert_idx = task_id / nth;
        uint64_t expert_id = expert_ids[expert_idx];
        int ith = task_id % nth;

        void *gate_proj_ptr =
            (uint8_t *)gate_proj_ +
            (expert_id * config_.intermediate_size + ith * config_.stride) *
                config_.hidden_size * ggml_type_size(config_.gate_type) /
                ggml_blck_size(config_.gate_type);

        float *gate_output_ptr =
            s_gate_output_[expert_idx] + ith * config_.stride;
        llamafile_sgemm(
            config_.stride, 1,
            config_.hidden_size / ggml_blck_size(config_.gate_type),
            gate_proj_ptr,
            config_.hidden_size / ggml_blck_size(config_.gate_type),
            gate_input_ptr,
            config_.hidden_size / ggml_blck_size(config_.gate_type),
            gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE,
            config_.gate_type,
            ggml_internal_get_type_traits(config_.gate_type).vec_dot_type,
            GGML_TYPE_F32, GGML_PREC_DEFAULT);

        void *up_proj_ptr =
            (uint8_t *)up_proj_ +
            (expert_id * config_.intermediate_size + ith * config_.stride) *
                config_.hidden_size * ggml_type_size(config_.up_type) /
                ggml_blck_size(config_.up_type);

        float *up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
        llamafile_sgemm(
            config_.stride, 1,
            config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr,
            config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr,
            config_.hidden_size / ggml_blck_size(config_.up_type),
            up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE,
            config_.up_type,
            ggml_internal_get_type_traits(config_.up_type).vec_dot_type,
            GGML_TYPE_F32, GGML_PREC_DEFAULT);
        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride;
             i++) {
            s_intermediate_fp32_[expert_idx][i] =
                act_fn(s_gate_output_[expert_idx][i]) *
                s_up_output_[expert_idx][i];
        }
        if (config_.stride %
                ggml_blck_size(ggml_internal_get_type_traits(config_.down_type)
                                   .vec_dot_type) ==
            0) {
            float *intermediate_fp32_ptr =
                s_intermediate_fp32_[expert_idx] + ith * config_.stride;
            void *down_input_ptr =
                s_down_input_[expert_idx] +
                ith * config_.stride *
                    ggml_type_size(
                        ggml_internal_get_type_traits(config_.down_type)
                            .vec_dot_type) /
                    ggml_blck_size(
                        ggml_internal_get_type_traits(config_.down_type)
                            .vec_dot_type);
            from_float(
                intermediate_fp32_ptr, down_input_ptr, config_.stride,
                ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    });
    if (config_.stride %
            ggml_blck_size(ggml_internal_get_type_traits(config_.down_type)
                               .vec_dot_type) !=
        0) {
        for (int i = 0; i < k; i++) {
            from_float(
                s_intermediate_fp32_[i], s_down_input_[i],
                config_.intermediate_size,
                ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    }
    nth = config_.hidden_size / config_.stride;
    cpuinfer->parallel_for(nth, [&](int task_id) {
        int ith = task_id;
        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride;
             i++) {
            s_output_fp32_[i] = 0;
        }
        for (int expert_idx = 0; expert_idx < k; expert_idx++) {
            uint64_t expert_id = expert_ids[expert_idx];

            void *down_proj_ptr =
                (uint8_t *)down_proj_ +
                (expert_id * config_.hidden_size + ith * config_.stride) *
                    config_.intermediate_size *
                    ggml_type_size(config_.down_type) /
                    ggml_blck_size(config_.down_type);

            float *down_output_ptr =
                s_down_output_[expert_idx] + ith * config_.stride;
            llamafile_sgemm(
                config_.stride, 1,
                config_.intermediate_size / ggml_blck_size(config_.down_type),
                down_proj_ptr,
                config_.intermediate_size / ggml_blck_size(config_.down_type),
                s_down_input_[expert_idx],
                config_.intermediate_size / ggml_blck_size(config_.down_type),
                down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE,
                config_.down_type,
                ggml_internal_get_type_traits(config_.down_type).vec_dot_type,
                GGML_TYPE_F32, GGML_PREC_DEFAULT);
            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride;
                 i++) {
                s_output_fp32_[i] +=
                    s_down_output_[expert_idx][i] * weights[expert_idx];
            }
        }
        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
            float *output_fp32_ptr = s_output_fp32_ + ith * config_.stride;
            void *output_ptr =
                (uint8_t *)output + ith * config_.stride *
                                        ggml_type_size(config_.hidden_type) /
                                        ggml_blck_size(config_.hidden_type);
            from_float(output_fp32_ptr, output_ptr, config_.stride,
                       config_.hidden_type);
        }
    });
    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
        from_float(s_output_fp32_, output, config_.hidden_size,
                   config_.hidden_type);
    }
}

void MOE::forward_many(int qlen, int k, const uint64_t *expert_ids,
                       const float *weights, const void *input, void *output,
                       CPUInfer *cpuinfer) {
    for (int i = 0; i < config_.expert_num; i++) {
        m_local_num_[i] = 0;
    }
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < k; j++) {
            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
        }
    }
    uint64_t offset = 0;
    for (int i = 0; i < config_.expert_num; i++) {
        m_local_gate_input_ptr_[i] =
            m_local_gate_input_ +
            offset * config_.hidden_size *
                ggml_type_size(ggml_internal_get_type_traits(config_.gate_type)
                                   .vec_dot_type) /
                ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type)
                                   .vec_dot_type);
        m_local_up_input_ptr_[i] =
            m_local_up_input_ +
            offset * config_.hidden_size *
                ggml_type_size(ggml_internal_get_type_traits(config_.up_type)
                                   .vec_dot_type) /
                ggml_blck_size(ggml_internal_get_type_traits(config_.up_type)
                                   .vec_dot_type);
        m_local_gate_output_ptr_[i] =
            m_local_gate_output_ + offset * config_.intermediate_size;
        m_local_up_output_ptr_[i] =
            m_local_up_output_ + offset * config_.intermediate_size;
        m_local_intermediate_fp32_ptr_[i] =
            m_local_intermediate_fp32_ + offset * config_.intermediate_size;
        m_local_down_input_ptr_[i] =
            m_local_down_input_ +
            offset * config_.intermediate_size *
                ggml_type_size(ggml_internal_get_type_traits(config_.down_type)
                                   .vec_dot_type) /
                ggml_blck_size(ggml_internal_get_type_traits(config_.down_type)
                                   .vec_dot_type);
        m_local_down_output_ptr_[i] =
            m_local_down_output_ + offset * config_.hidden_size;
        offset += m_local_num_[i];
    }
    cpuinfer->parallel_for(qlen, [&](int i) {
        const void *gate_input_ptr;
        const void *up_input_ptr;
        if (config_.hidden_type ==
                ggml_internal_get_type_traits(config_.gate_type).vec_dot_type &&
            config_.hidden_type ==
                ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
            gate_input_ptr = up_input_ptr =
                (uint8_t *)input + i * config_.hidden_size *
                                       ggml_type_size(config_.hidden_type) /
                                       ggml_blck_size(config_.hidden_type);
        } else {
            to_float(
                (uint8_t *)input + i * config_.hidden_size *
                                       ggml_type_size(config_.hidden_type) /
                                       ggml_blck_size(config_.hidden_type),
                m_input_fp32_[i], config_.hidden_size, config_.hidden_type);
            if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type ==
                ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
                from_float(m_input_fp32_[i], m_gate_input_[i],
                           config_.hidden_size,
                           ggml_internal_get_type_traits(config_.gate_type)
                               .vec_dot_type);
                gate_input_ptr = up_input_ptr = m_gate_input_[i];
            } else {
                if (config_.hidden_type !=
                    ggml_internal_get_type_traits(config_.gate_type)
                        .vec_dot_type) {
                    from_float(m_input_fp32_[i], m_gate_input_[i],
                               config_.hidden_size,
                               ggml_internal_get_type_traits(config_.gate_type)
                                   .vec_dot_type);
                    gate_input_ptr = m_gate_input_[i];
                } else {
                    gate_input_ptr = (uint8_t *)input +
                                     i * config_.hidden_size *
                                         ggml_type_size(config_.hidden_type) /
                                         ggml_blck_size(config_.hidden_type);
                }
                if (config_.hidden_type !=
                    ggml_internal_get_type_traits(config_.up_type)
                        .vec_dot_type) {
                    from_float(m_input_fp32_[i], m_up_input_[i],
                               config_.hidden_size,
                               ggml_internal_get_type_traits(config_.up_type)
                                   .vec_dot_type);
                    up_input_ptr = m_up_input_[i];
                } else {
                    up_input_ptr = (uint8_t *)input +
                                   i * config_.hidden_size *
                                       ggml_type_size(config_.hidden_type) /
                                       ggml_blck_size(config_.hidden_type);
                }
            }
        }
        for (int j = 0; j < k; j++) {
            memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] +
                       m_local_pos_[i][j] * config_.hidden_size *
                           ggml_type_size(
                               ggml_internal_get_type_traits(config_.gate_type)
                                   .vec_dot_type) /
                           ggml_blck_size(
                               ggml_internal_get_type_traits(config_.gate_type)
                                   .vec_dot_type),
                   gate_input_ptr,
                   config_.hidden_size *
                       ggml_type_size(
                           ggml_internal_get_type_traits(config_.gate_type)
                               .vec_dot_type) /
                       ggml_blck_size(
                           ggml_internal_get_type_traits(config_.gate_type)
                               .vec_dot_type));
            memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] +
                       m_local_pos_[i][j] * config_.hidden_size *
                           ggml_type_size(
                               ggml_internal_get_type_traits(config_.up_type)
                                   .vec_dot_type) /
                           ggml_blck_size(
                               ggml_internal_get_type_traits(config_.up_type)
                                   .vec_dot_type),
                   up_input_ptr,
                   config_.hidden_size *
                       ggml_type_size(
                           ggml_internal_get_type_traits(config_.up_type)
                               .vec_dot_type) /
                       ggml_blck_size(
                           ggml_internal_get_type_traits(config_.up_type)
                               .vec_dot_type));
        }
    });
    int stride = QK_K;
    int nth = config_.intermediate_size / stride;
    cpuinfer->parallel_for(nth * config_.expert_num, [&](int task_id) {
        uint64_t expert_idx = task_id / nth;
        int ith = task_id % nth;
        void *gate_input_ptr = m_local_gate_input_ptr_[expert_idx];

        void *gate_proj_ptr =
            (uint8_t *)gate_proj_ +
            (expert_idx * config_.intermediate_size + ith * stride) *
                config_.hidden_size * ggml_type_size(config_.gate_type) /
                ggml_blck_size(config_.gate_type);

        float *gate_output_ptr =
            m_local_gate_output_ptr_[expert_idx] + ith * stride;
        llamafile_sgemm(
            stride, m_local_num_[expert_idx],
            config_.hidden_size / ggml_blck_size(config_.gate_type),
            gate_proj_ptr,
            config_.hidden_size / ggml_blck_size(config_.gate_type),
            gate_input_ptr,
            config_.hidden_size / ggml_blck_size(config_.gate_type),
            gate_output_ptr, config_.intermediate_size, 0, 1,
            GGML_TASK_TYPE_COMPUTE, config_.gate_type,
            ggml_internal_get_type_traits(config_.gate_type).vec_dot_type,
            GGML_TYPE_F32, GGML_PREC_DEFAULT);
        void *up_input_ptr = m_local_up_input_ptr_[expert_idx];

        void *up_proj_ptr =
            (uint8_t *)up_proj_ +
            (expert_idx * config_.intermediate_size + ith * stride) *
                config_.hidden_size * ggml_type_size(config_.up_type) /
                ggml_blck_size(config_.up_type);

        float *up_output_ptr =
            m_local_up_output_ptr_[expert_idx] + ith * stride;
        llamafile_sgemm(
            stride, m_local_num_[expert_idx],
            config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr,
            config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr,
            config_.hidden_size / ggml_blck_size(config_.up_type),
            up_output_ptr, config_.intermediate_size, 0, 1,
            GGML_TASK_TYPE_COMPUTE, config_.up_type,
            ggml_internal_get_type_traits(config_.up_type).vec_dot_type,
            GGML_TYPE_F32, GGML_PREC_DEFAULT);
        for (int i = 0; i < m_local_num_[expert_idx]; i++) {
            for (int j = ith * stride; j < (ith + 1) * stride; j++) {
                m_local_intermediate_fp32_ptr_
                    [expert_idx][i * config_.intermediate_size + j] =
                        act_fn(m_local_gate_output_ptr_
                                   [expert_idx]
                                   [i * config_.intermediate_size + j]) *
                        m_local_up_output_ptr_[expert_idx]
                                              [i * config_.intermediate_size +
                                               j];
            }
            float *intermediate_fp32_ptr =
                m_local_intermediate_fp32_ptr_[expert_idx] +
                i * config_.intermediate_size + ith * stride;
            void *down_input_ptr =
                m_local_down_input_ptr_[expert_idx] +
                i * config_.intermediate_size *
                    ggml_type_size(
                        ggml_internal_get_type_traits(config_.down_type)
                            .vec_dot_type) /
                    ggml_blck_size(
                        ggml_internal_get_type_traits(config_.down_type)
                            .vec_dot_type) +
                ith * stride *
                    ggml_type_size(
                        ggml_internal_get_type_traits(config_.down_type)
                            .vec_dot_type) /
                    ggml_blck_size(
                        ggml_internal_get_type_traits(config_.down_type)
                            .vec_dot_type);
            from_float(
                intermediate_fp32_ptr, down_input_ptr, stride,
                ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    });
    stride = QK_K;
    nth = config_.hidden_size / stride;
    cpuinfer->parallel_for(nth * config_.expert_num, [&](int task_id) {
        uint64_t expert_idx = task_id / nth;
        int ith = task_id % nth;
        void *down_input_ptr = m_local_down_input_ptr_[expert_idx];

        void *down_proj_ptr =
            (uint8_t *)down_proj_ +
            (expert_idx * config_.hidden_size + ith * stride) *
                config_.intermediate_size * ggml_type_size(config_.down_type) /
                ggml_blck_size(config_.down_type);

        float *down_output_ptr =
            m_local_down_output_ptr_[expert_idx] + ith * stride;
        llamafile_sgemm(
            stride, m_local_num_[expert_idx],
            config_.intermediate_size / ggml_blck_size(config_.down_type),
            down_proj_ptr,
            config_.intermediate_size / ggml_blck_size(config_.down_type),
            down_input_ptr,
            config_.intermediate_size / ggml_blck_size(config_.down_type),
            down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE,
            config_.down_type,
            ggml_internal_get_type_traits(config_.down_type).vec_dot_type,
            GGML_TYPE_F32, GGML_PREC_DEFAULT);
    });
    cpuinfer->parallel_for(qlen, [&](int i) {
        for (int e = 0; e < config_.hidden_size; e++) {
            m_output_fp32_[i][e] = 0;
        }
        for (int j = 0; j < k; j++) {
            for (int e = 0; e < config_.hidden_size; e++) {
                m_output_fp32_[i][e] +=
                    m_local_down_output_ptr_[expert_ids[i * k + j]]
                                            [m_local_pos_[i][j] *
                                                 config_.hidden_size +
                                             e] *
                    weights[i * k + j];
            }
        }
        from_float(m_output_fp32_[i],
                   (uint8_t *)output + i * config_.hidden_size *
                                           ggml_type_size(config_.hidden_type) /
                                           ggml_blck_size(config_.hidden_type),
                   config_.hidden_size, config_.hidden_type);
    });
}

/**
 * MoE forward for `qlen` tokens, each routed to `k` experts.
 *
 * Fewer than group_min_len remaining tokens are processed one at a time via
 * forward_one. Otherwise the next chunk of at most group_max_len tokens goes
 * through forward_many, and the loop continues with the rest. If
 * group_max_len is not positive, forward_many cannot be used and every token
 * goes through forward_one.
 *
 * @param qlen       number of tokens; qlen <= 0 is a no-op.
 * @param k          experts per token.
 * @param expert_ids [qlen * k] expert index per (token, slot).
 * @param weights    [qlen * k] mixing weight per (token, slot).
 * @param input      qlen hidden_size rows encoded in config_.hidden_type.
 * @param output     qlen hidden_size rows encoded in config_.hidden_type.
 * @param cpuinfer   thread pool forwarded to the workers.
 */
void MOE::forward(int qlen, int k, const uint64_t *expert_ids,
                  const float *weights, const void *input, void *output,
                  CPUInfer *cpuinfer) {
    // Bytes occupied by `rows` hidden_size rows in config_.hidden_type.
    auto rows_bytes = [&](int rows) -> size_t {
        return (size_t)rows * config_.hidden_size *
               ggml_type_size(config_.hidden_type) /
               ggml_blck_size(config_.hidden_type);
    };
    const uint8_t *in = (const uint8_t *)input;
    uint8_t *out = (uint8_t *)output;

    // A loop bounded by the remaining token count always terminates, even
    // for qlen == 0 with group_min_len <= 0 or for group_max_len <= 0.
    while (qlen > 0) {
        if (qlen < config_.group_min_len || config_.group_max_len <= 0) {
            for (int i = 0; i < qlen; i++) {
                forward_one(k, expert_ids + (size_t)i * k,
                            weights + (size_t)i * k, in + rows_bytes(i),
                            out + rows_bytes(i), cpuinfer);
            }
            return;
        }
        const int forward_len = std::min(config_.group_max_len, qlen);
        forward_many(forward_len, k, expert_ids, weights, in, out, cpuinfer);
        qlen -= forward_len;
        expert_ids += (size_t)forward_len * k;
        weights += (size_t)forward_len * k;
        in += rows_bytes(forward_len);
        out += rows_bytes(forward_len);
    }
}
